{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "A Gentle Introduction to ``torch.autograd``\n",
    "---------------------------------\n",
    "\n",
    "``torch.autograd`` is PyTorch’s automatic differentiation engine that powers\n",
    "neural network training. In this section, you will get a conceptual\n",
    "understanding of how autograd helps a neural network train.\n",
    "\n",
    "Background\n",
    "~~~~~~~~~~\n",
    "Neural networks (NNs) are a collection of nested functions that are\n",
    "executed on some input data. These functions are defined by *parameters*\n",
    "(consisting of weights and biases), which in PyTorch are stored in\n",
    "tensors.\n",
    "\n",
    "Training a NN happens in two steps:\n",
    "\n",
    "**Forward Propagation**: In forward prop, the NN makes its best guess\n",
    "about the correct output. It runs the input data through each of its\n",
    "functions to make this guess.\n",
    "\n",
    "**Backward Propagation**: In backprop, the NN adjusts its parameters\n",
    "proportionate to the error in its guess. It does this by traversing\n",
    "backwards from the output, collecting the derivatives of the error with\n",
    "respect to the parameters of the functions (*gradients*), and optimizing\n",
    "the parameters using gradient descent. For a more detailed walkthrough\n",
    "of backprop, check out this `video from\n",
    "3Blue1Brown <https://www.youtube.com/watch?v=tIeHLnjs5U8>`__.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "Usage in PyTorch\n",
    "~~~~~~~~~~~\n",
    "Let's take a look at a single training step.\n",
    "For this example, we load a pretrained resnet18 model from ``torchvision``.\n",
    "We create a random data tensor to represent a single image with 3 channels, and height & width of 64,\n",
    "and its corresponding ``label`` initialized to some random values.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/resnet18-5c106cde.pth\" to C:\\Users\\Administrator/.cache\\torch\\hub\\checkpoints\\resnet18-5c106cde.pth\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3d5f564641d4fdb83e0b1978fe8c6df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch, torchvision\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "data = torch.rand(1, 3, 64, 64)  # 1:batch size  3: channels 64*64: image size\n",
    "labels = torch.rand(1, 1000)  # 一个1000位的向量，因为resnet原本是训练一个1000类的分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we run the input data through the model through each of its layers to make a prediction.\n",
    "This is the **forward pass**.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1000])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction = model(data) # forward pass\n",
    "prediction.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the model's prediction and the corresponding label to calculate the error (``loss``).\n",
    "The next step is to backpropagate this error through the network.\n",
    "Backward propagation is kicked off when we call ``.backward()`` on the error tensor.\n",
    "Autograd then calculates and stores the gradients for each model parameter in the parameter's ``.grad`` attribute.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "loss = (prediction - labels).sum()\n",
    "loss.backward() # backward pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we load an optimizer, in this case SGD with a learning rate of 0.01 and momentum of 0.9.\n",
    "We register all the parameters of the model in the optimizer.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we call ``.step()`` to initiate gradient descent. The optimizer adjusts each parameter by its gradient stored in ``.grad``.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "optim.step() #gradient descent"
   ]
  },
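  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what ``.step()`` computes here, assuming the SGD options not set above keep their\n",
    "documented defaults (no dampening, weight decay, or Nesterov momentum): for each parameter $p$ with gradient\n",
    "$g$ taken from ``.grad``, the optimizer keeps a momentum buffer $v$. The symbols $p$, $g$, $v$, $\\mu$ and\n",
    "$\\eta$ are notation introduced for this note, not names from the code.\n",
    "\n",
    "\\begin{align}v_{t} = \\mu\\, v_{t-1} + g_{t}, \\qquad p_{t} = p_{t-1} - \\eta\\, v_{t}\\end{align}\n",
    "\n",
    "With ``lr=1e-2`` and ``momentum=0.9``, $\\eta = 0.01$ and $\\mu = 0.9$. On the first step the buffer is\n",
    "initialized to the gradient itself, so the single update above reduces to $p_{1} = p_{0} - 0.01\\, g_{1}$."
   ]
  },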
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 总结一下上面的步骤：\n",
    "- 首先得到prediction，然后根据label来计算loss\n",
    "- 然后loss进行`.backward()`操作。这一步操作完，每个参数都会把自己的`.grad`属性给更新\n",
    "- 再注册SGD，然后让优化器进行`.step()`，这时优化器就会根据每个参数已有的`.grad`属性进行更新。这么，一步梯度下降更新就完成了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, you have everything you need to train your neural network.\n",
    "The below sections detail the workings of autograd - feel free to skip them.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------------\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Differentiation in Autograd\n",
    "~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
    "Let's take a look at how ``autograd`` collects gradients. We create two tensors ``a`` and ``b`` with\n",
    "``requires_grad=True``. This signals to ``autograd`` that every operation on them should be tracked.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "a = torch.tensor([2., 3.], requires_grad=True)\n",
    "b = torch.tensor([6., 4.], requires_grad=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create another tensor ``Q`` from ``a`` and ``b``.\n",
    "\n",
    "\\begin{align}Q = 3a^3 - b^2\\end{align}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(53., grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q = 3*a**3 - b**2\n",
    "Q.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's assume ``a`` and ``b`` to be parameters of an NN, and ``Q``\n",
    "to be the error. In NN training, we want gradients of the error\n",
    "w.r.t. parameters, i.e.\n",
    "\n",
    "\\begin{align}\\frac{\\partial Q}{\\partial a} = 9a^2\\end{align}\n",
    "\n",
    "\\begin{align}\\frac{\\partial Q}{\\partial b} = -2b\\end{align}\n",
    "\n",
    "\n",
    "When we call ``.backward()`` on ``Q``, autograd calculates these gradients\n",
    "and stores them in the respective tensors' ``.grad`` attribute.\n",
    "\n",
    "We need to explicitly pass a ``gradient`` argument in ``Q.backward()`` because it is a vector.\n",
    "``gradient`` is a tensor of the same shape as ``Q``, and it represents the\n",
    "gradient of Q w.r.t. itself, i.e.\n",
    "\n",
    "\\begin{align}\\frac{dQ}{dQ} = 1\\end{align}\n",
    "\n",
    "Equivalently, we can also aggregate Q into a scalar and call backward implicitly, like ``Q.sum().backward()``.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([36., 81.])\n",
      "tensor([-12.,  -8.])\n"
     ]
    }
   ],
   "source": [
    "external_grad = torch.tensor([1., 1.])\n",
    "Q.backward(gradient=external_grad)\n",
    "\n",
    "print(a.grad)\n",
    "print(b.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gradients are now deposited in ``a.grad`` and ``b.grad``\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([True, True])\n",
      "tensor([True, True])\n"
     ]
    }
   ],
   "source": [
    "# check if collected gradients are correct\n",
    "print(9*a**2 == a.grad)\n",
    "print(-2*b == b.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([False, False])\n",
      "tensor([False, False])\n",
      "tensor([ 72., 162.])\n",
      "tensor([-24., -16.])\n"
     ]
    }
   ],
   "source": [
    "Q = 3*a**3 - b**2  # 重新构建一次\n",
    "\n",
    "Q.sum().backward()\n",
    "\n",
    "# check if collected gradients are correct\n",
    "print(9*a**2 == a.grad)\n",
    "print(-2*b == b.grad)\n",
    "\n",
    "print(a.grad)\n",
    "print(b.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 这里有一个值得注意的点，关于graph的释放和梯度的累积\n",
    "先跑\n",
    "```python\n",
    "external_grad = torch.tensor([1., 1.])\n",
    "Q.backward(gradient=external_grad)\n",
    "```\n",
    "再跑\n",
    "```python\n",
    "Q.sum().backward()\n",
    "```\n",
    "\n",
    "1. 上面这个代码，如果直接运行，会报错：\n",
    "```shell\n",
    "RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. \n",
    "Specify retain_graph=True when calling backward the first time.\n",
    "```\n",
    "意思就是，Q和a，b构成的这个图，在执行第一次Q的backward的时候，就被释放了，所以没法再次backward了。\n",
    "\n",
    "但注意，图被释放，这些变量其实还在，这时依然可以打印出Q，a，b。只是Q跟a，b之间的关系没了。\n",
    "\n",
    "2. 如果单独再把`Q = 3*a**3 - b**2`运行一下，上面这代码就可以运行，但是会发现计算出来的梯度就不对了：\n",
    "```python\n",
    "print(9*a**2 == a.grad)\n",
    "print(-2*b == b.grad)\n",
    ">>>\n",
    "tensor([False, False])\n",
    "tensor([False, False])\n",
    "```\n",
    "\n",
    "这是因为**梯度被累积**了。前面a和b的grad打印出来是\n",
    "```python\n",
    "tensor([36., 81.])\n",
    "tensor([-12.,  -8.])\n",
    "```\n",
    "\n",
    "第二次backward后得到：\n",
    "```python\n",
    "tensor([ 72., 162.])\n",
    "tensor([-24., -16.])\n",
    "```\n",
    "\n",
    "可以看出是直接相加了。\n",
    "\n",
    "3. 想重复运行backward，就需要把graph重建。\n",
    "一个方法是把Q，a，b的建立重新跑一遍。这样梯度不会累加。因为相当于是新graph了。\n",
    "\n",
    "另一种方法就是在backward时添加`retain_graph=True`参数，这时梯度依然会累加。\n",
    "\n",
    "\n",
    "\n",
    "> **梯度累加**我目前还没搞懂，干啥用的？什么时候累加什么时候不要？\n",
    "\n",
    "参考一下这个：\n",
    "https://www.cnblogs.com/zjuhaohaoxuexi/p/15057530.html"
   ]
  },
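  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the accumulation concrete, here is the arithmetic behind the two printouts above (a worked check that uses only values already shown). Each backward pass adds the newly computed gradient to whatever `.grad` already holds:\n",
    "\n",
    "\\begin{align}\\text{a.grad}: (36, 81) + (36, 81) = (72, 162), \\qquad \\text{b.grad}: (-12, -8) + (-12, -8) = (-24, -16)\\end{align}\n",
    "\n",
    "This is why training loops usually reset gradients between steps, for example with the optimizer's `zero_grad()` method. Skipping the reset is useful when the sum is what you want, such as adding up gradients from several small batches before a single `.step()`."
   ]
  },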
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 关于求导的一些澄清\n",
    "其实这个例子举得不好，把我都搞迷糊了，看下面才看清楚。\n",
    "\n",
    "首先，求导一般只能一个标量对其他求导。这里的Q是一个向量，所以本来应该先求sum，再求导。\n",
    "\n",
    "但是这里给出另一种方式，就是通过一个`external_grad`向量，来直接对Q这个向量求导，这里`external_grad = torch.tensor([1., 1.])`的目的，就是为了达到sum的效果，即sum(Q)对Q的导数，就是这里的`external_grad`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional Reading - Vector Calculus using ``autograd``\n",
    "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
    "\n",
    "Mathematically, if you have a vector valued function\n",
    "$\\vec{y}=f(\\vec{x})$, then the gradient of $\\vec{y}$ with\n",
    "respect to $\\vec{x}$ is a Jacobian matrix $J$:\n",
    "\n",
    "\\begin{align}J\n",
    "     =\n",
    "      \\left(\\begin{array}{cc}\n",
    "      \\frac{\\partial \\bf{y}}{\\partial x_{1}} &\n",
    "      ... &\n",
    "      \\frac{\\partial \\bf{y}}{\\partial x_{n}}\n",
    "      \\end{array}\\right)\n",
    "     =\n",
    "     \\left(\\begin{array}{ccc}\n",
    "      \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{1}}{\\partial x_{n}}\\\\\n",
    "      \\vdots & \\ddots & \\vdots\\\\\n",
    "      \\frac{\\partial y_{m}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
    "      \\end{array}\\right)\\end{align}\n",
    "\n",
    "Generally speaking, ``torch.autograd`` is an engine for computing\n",
    "vector-Jacobian product. That is, given any vector $\\vec{v}$, compute the product\n",
    "$J^{T}\\cdot \\vec{v}$\n",
    "\n",
    "If $\\vec{v}$ happens to be the gradient of a scalar function $l=g\\left(\\vec{y}\\right)$:\n",
    "\n",
    "\\begin{align}\\vec{v}\n",
    "   =\n",
    "   \\left(\\begin{array}{ccc}\\frac{\\partial l}{\\partial y_{1}} & \\cdots & \\frac{\\partial l}{\\partial y_{m}}\\end{array}\\right)^{T}\\end{align}\n",
    "\n",
    "then by the chain rule, the vector-Jacobian product would be the\n",
    "gradient of $l$ with respect to $\\vec{x}$:\n",
    "\n",
    "\\begin{align}J^{T}\\cdot \\vec{v}=\\left(\\begin{array}{ccc}\n",
    "      \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{1}}\\\\\n",
    "      \\vdots & \\ddots & \\vdots\\\\\n",
    "      \\frac{\\partial y_{1}}{\\partial x_{n}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
    "      \\end{array}\\right)\\left(\\begin{array}{c}\n",
    "      \\frac{\\partial l}{\\partial y_{1}}\\\\\n",
    "      \\vdots\\\\\n",
    "      \\frac{\\partial l}{\\partial y_{m}}\n",
    "      \\end{array}\\right)=\\left(\\begin{array}{c}\n",
    "      \\frac{\\partial l}{\\partial x_{1}}\\\\\n",
    "      \\vdots\\\\\n",
    "      \\frac{\\partial l}{\\partial x_{n}}\n",
    "      \\end{array}\\right)\\end{align}\n",
    "\n",
    "This characteristic of vector-Jacobian product is what we use in the above example;\n",
    "``external_grad`` represents $\\vec{v}$.\n",
    "\n",
    "\n"
   ]
  },
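  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked instance of the product above, take the earlier example with $a = (2, 3)$ and $b = (6, 4)$.\n",
    "Each $Q_i = 3a_i^3 - b_i^2$ depends only on $a_i$ and $b_i$, so the Jacobians with respect to $a$ and $b$\n",
    "(written $J_a$ and $J_b$ here, notation introduced for this note) are diagonal. With $\\vec{v} = (1, 1)^{T}$\n",
    "in the role of ``external_grad``:\n",
    "\n",
    "\\begin{align}J_a^{T}\\cdot \\vec{v} = \\left(\\begin{array}{cc} 9a_1^2 & 0\\\\ 0 & 9a_2^2 \\end{array}\\right) \\left(\\begin{array}{c} 1\\\\ 1 \\end{array}\\right) = \\left(\\begin{array}{c} 36\\\\ 81 \\end{array}\\right)\\end{align}\n",
    "\n",
    "\\begin{align}J_b^{T}\\cdot \\vec{v} = \\left(\\begin{array}{cc} -2b_1 & 0\\\\ 0 & -2b_2 \\end{array}\\right) \\left(\\begin{array}{c} 1\\\\ 1 \\end{array}\\right) = \\left(\\begin{array}{c} -12\\\\ -8 \\end{array}\\right)\\end{align}\n",
    "\n",
    "These are the values printed for ``a.grad`` and ``b.grad``."
   ]
  },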
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computational Graph\n",
    "~~~~~~~~~~~~~~~~~~~\n",
    "\n",
    "Conceptually, autograd keeps a record of data (tensors) & all executed\n",
    "operations (along with the resulting new tensors) in a directed acyclic\n",
    "graph (DAG) consisting of\n",
    "`Function <https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function>`__\n",
    "objects. In this DAG, leaves are the input tensors, roots are the output\n",
    "tensors. By tracing this graph from roots to leaves, you can\n",
    "automatically compute the gradients using the chain rule.\n",
    "\n",
    "In a forward pass, autograd does two things simultaneously:\n",
    "\n",
    "- run the requested operation to compute a resulting tensor, and\n",
    "- maintain the operation’s *gradient function* in the DAG.\n",
    "\n",
    "The backward pass kicks off when ``.backward()`` is called on the DAG\n",
    "root. ``autograd`` then:\n",
    "\n",
    "- computes the gradients from each ``.grad_fn``,\n",
    "- accumulates them in the respective tensor’s ``.grad`` attribute, and\n",
    "- using the chain rule, propagates all the way to the leaf tensors.\n",
    "\n",
    "Below is a visual representation of the DAG in our example. In the graph,\n",
    "the arrows are in the direction of the forward pass. The nodes represent the backward functions\n",
    "of each operation in the forward pass. The leaf nodes in blue represent our leaf tensors ``a`` and ``b``.\n",
    "\n",
    ".. figure:: /_static/img/dag_autograd.png\n",
    "\n",
    "<div class=\"alert alert-info\"><h4>Note</h4><p>**DAGs are dynamic in PyTorch**\n",
    "  An important thing to note is that the graph is recreated from scratch; after each\n",
    "  ``.backward()`` call, autograd starts populating a new graph. This is\n",
    "  exactly what allows you to use control flow statements in your model;\n",
    "  you can change the shape, size and operations at every iteration if\n",
    "  needed.</p></div>\n",
    "\n",
    "Exclusion from the DAG\n",
    "^^^^^^^^^^^^^^^^^^^^^^\n",
    "\n",
    "``torch.autograd`` tracks operations on all tensors which have their\n",
    "``requires_grad`` flag set to ``True``. For tensors that don’t require\n",
    "gradients, setting this attribute to ``False`` excludes it from the\n",
    "gradient computation DAG.\n",
    "\n",
    "The output tensor of an operation will require gradients even if only a\n",
    "single input tensor has ``requires_grad=True``.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "x = torch.rand(5, 5)\n",
    "y = torch.rand(5, 5)\n",
    "z = torch.rand((5, 5), requires_grad=True)\n",
    "\n",
    "a = x + y\n",
    "print(f\"Does `a` require gradients? : {a.requires_grad}\")\n",
    "b = x + z\n",
    "print(f\"Does `b` require gradients?: {b.requires_grad}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In a NN, parameters that don't compute gradients are usually called **frozen parameters**.\n",
    "It is useful to \"freeze\" part of your model if you know in advance that you won't need the gradients of those parameters\n",
    "(this offers some performance benefits by reducing autograd computations).\n",
    "\n",
    "Another common usecase where exclusion from the DAG is important is for\n",
    "`finetuning a pretrained network <https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html>`__\n",
    "\n",
    "In finetuning, we freeze most of the model and typically only modify the classifier layers to make predictions on new labels.\n",
    "Let's walk through a small example to demonstrate this. As before, we load a pretrained resnet18 model, and freeze all the parameters.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from torch import nn, optim\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "\n",
    "# Freeze all the parameters in the network\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's say we want to finetune the model on a new dataset with 10 labels.\n",
    "In resnet, the classifier is the last linear layer ``model.fc``.\n",
    "We can simply replace it with a new linear layer (unfrozen by default)\n",
    "that acts as our classifier.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.fc = nn.Linear(512, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now all parameters in the model, except the parameters of ``model.fc``, are frozen.\n",
    "The only parameters that compute gradients are the weights and bias of ``model.fc``.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Optimize only the classifier\n",
    "optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice although we register all the parameters in the optimizer,\n",
    "the only parameters that are computing gradients (and hence updated in gradient descent)\n",
    "are the weights and bias of the classifier.\n",
    "\n",
    "The same exclusionary functionality is available as a context manager in\n",
    "`torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html>`__\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------------\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Further readings:\n",
    "~~~~~~~~~~~~~~~~~~~\n",
    "\n",
    "-  `In-place operations & Multithreaded Autograd <https://pytorch.org/docs/stable/notes/autograd.html>`__\n",
    "-  `Example implementation of reverse-mode autodiff <https://colab.research.google.com/drive/1VpeE6UvEPRz9HmsHh1KS0XxXjYu533EC>`__\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn, optim\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "\n",
    "# Freeze all the parameters in the network\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Let’s say we want to finetune the model on a new dataset with 10 labels. In resnet, the classifier is the last linear layer model.fc. \n",
    "We can simply replace it with a new linear layer (unfrozen by default) that acts as our classifier.\n",
    "\"\"\"\n",
    "model.fc = nn.Linear(512, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('conv1',\n",
       "              Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)),\n",
       "             ('bn1',\n",
       "              BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n",
       "             ('relu', ReLU(inplace=True)),\n",
       "             ('maxpool',\n",
       "              MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)),\n",
       "             ('layer1',\n",
       "              Sequential(\n",
       "                (0): BasicBlock(\n",
       "                  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                )\n",
       "                (1): BasicBlock(\n",
       "                  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                )\n",
       "              )),\n",
       "             ('layer2',\n",
       "              Sequential(\n",
       "                (0): BasicBlock(\n",
       "                  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (downsample): Sequential(\n",
       "                    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "                    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  )\n",
       "                )\n",
       "                (1): BasicBlock(\n",
       "                  (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                )\n",
       "              )),\n",
       "             ('layer3',\n",
       "              Sequential(\n",
       "                (0): BasicBlock(\n",
       "                  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (downsample): Sequential(\n",
       "                    (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "                    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  )\n",
       "                )\n",
       "                (1): BasicBlock(\n",
       "                  (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                )\n",
       "              )),\n",
       "             ('layer4',\n",
       "              Sequential(\n",
       "                (0): BasicBlock(\n",
       "                  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (downsample): Sequential(\n",
       "                    (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "                    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  )\n",
       "                )\n",
       "                (1): BasicBlock(\n",
       "                  (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                  (relu): ReLU(inplace=True)\n",
       "                  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "                  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "                )\n",
       "              )),\n",
       "             ('avgpool', AdaptiveAvgPool2d(output_size=(1, 1))),\n",
       "             ('fc', Linear(in_features=512, out_features=10, bias=True))])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model._modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimize only the classifier\n",
    "optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 该教程完整版：\n",
    "https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
